{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import gym\n",
    "import numpy as np\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "train = True\n",
    "test = True\n",
    "gamma = 0.9\n",
    "epsilon = 0.1\n",
    "num_episodes = 10000\n",
    "max_steps_per_episode = 200"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "af76a109d37ce49d",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class QNet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = torch.nn.Linear(in_features=16, out_features=4, bias=False)\n",
    "        self.initialize_weights()\n",
    "\n",
    "    def initialize_weights(self):\n",
    "        torch.nn.init.uniform_(self.fc1.weight, a=0, b=0.001)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc1(x)\n",
    "\n",
    "\n",
    "def to_onehot(i, state_n):\n",
    "    array = np.zeros(state_n, 'uint8')\n",
    "    array[i] = 1\n",
    "    return array\n",
    "\n",
    "\n",
    "env = gym.make('FrozenLake-v1')\n",
    "q_net = QNet()\n",
    "optimizer = optim.Adam(\n",
    "    q_net.parameters(),\n",
    "    lr=0.01,\n",
    "    weight_decay=0.0005\n",
    ")\n",
    "MSE = torch.nn.MSELoss()\n",
    "for episode in range(num_episodes):\n",
    "    # 初始化环境\n",
    "    state, init_info = env.reset()\n",
    "    total_reward = 0\n",
    "\n",
    "    for step in range(max_steps_per_episode):\n",
    "        with torch.no_grad():\n",
    "            action_predict = q_net(torch.tensor([to_onehot(state, 16)], dtype=torch.float32)).detach().numpy()\n",
    "        if np.random.rand(1) < epsilon:\n",
    "            action_predict = action_predict + np.random.randn(1, env.action_space.n) * (1 / (episode + 1))\n",
    "        action = np.argmax(action_predict)\n",
    "\n",
    "        # 执行动作\n",
    "        new_state, reward, done, truncated, info = env.step(action)\n",
    "\n",
    "        # 获取下一步的动作\n",
    "        with torch.no_grad():\n",
    "            action_predict_later = q_net(torch.tensor([to_onehot(new_state, 16)], dtype=torch.float32))\n",
    "\n",
    "        # 在Q-Learning中，策略是贪婪的，所以我们使用“max”来选择下一个动作\n",
    "        max_q = np.max(action_predict_later.detach().numpy())\n",
    "        target_q = action_predict\n",
    "        target_q[0, action] = reward + gamma * max_q\n",
    "\n",
    "        # 更新q_net参数\n",
    "        q_net.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        action_predict = q_net(torch.tensor([to_onehot(state, 16)], dtype=torch.float32))\n",
    "\n",
    "        # 计算损失函数\n",
    "        loss = MSE(action_predict, torch.tensor(target_q, dtype=torch.float32))\n",
    "\n",
    "        # 反向传播\n",
    "        loss.backward()\n",
    "\n",
    "        # 通过优化器更新模型参数\n",
    "        optimizer.step()\n",
    "\n",
    "        state = new_state\n",
    "        total_reward += reward\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "    print(f\"Episode {episode + 1}, Total Reward: {total_reward}\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4fd33291b809d0b5",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "env = gym.make('FrozenLake-v1', render_mode=\"human\")\n",
    "num_episodes_play = 10\n",
    "\n",
    "for _ in range(num_episodes_play):\n",
    "    state, init_info = env.reset()\n",
    "    total_reward = 0\n",
    "\n",
    "    for step in range(max_steps_per_episode):\n",
    "        # 选择最优动作\n",
    "        with torch.no_grad():\n",
    "            action_predict = q_net(torch.tensor([to_onehot(state, 16)], dtype=torch.float32)).numpy()\n",
    "        action = np.argmax(action_predict)\n",
    "\n",
    "        # 执行动作\n",
    "        new_state, reward, done, truncated, info = env.step(action)\n",
    "\n",
    "        state = new_state\n",
    "        total_reward += reward\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "    print(f\"Playing Episode, Total Reward: {total_reward}\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4c8fa9c006a99c7c",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "env.close()  # 关闭环境"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b32385aff12c2b0c",
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
